from dataclasses import dataclass

from ..utils import get_aligned_idx
from ..utils.typing import EmbeddingsTensor, ExpandedTargetIdsTensor, IdsTensor, OneOrMoreTokenSequences
from .data_utils import TensorWrapper


@dataclass(eq=False, repr=False)
class BatchEncoding(TensorWrapper):
    """Output produced by the tokenization process using :meth:`~inseq.models.AttributionModel.encode`.

    Attributes:
        input_ids (:obj:`torch.Tensor`): Batch of token ids with shape ``[batch_size, longest_seq_length]``.
            Extra tokens for each sentence are padded, and truncation to ``max_seq_length`` is performed.
        attention_mask (:obj:`torch.Tensor`): Batch of attention masks with shape ``[batch_size, longest_seq_length]``.
            1 for positions that are valid, 0 for padded positions.
        input_tokens (:obj:`list(list(str))`, optional): List of lists containing tokens for each sentence in the
            batch.
        baseline_ids (:obj:`torch.Tensor`, optional): Batch of reference token ids with shape
            ``[batch_size, longest_seq_length]``. Used for attribution methods requiring a baseline input (e.g. IG).
    """

    input_ids: IdsTensor
    attention_mask: IdsTensor
    input_tokens: OneOrMoreTokenSequences | None = None
    baseline_ids: IdsTensor | None = None

    def __len__(self) -> int:
        """Return the number of entries in the batch.

        Uses ``len(input_tokens)`` when tokens are available. ``input_tokens`` defaults to ``None``, and
        ``len(None)`` would raise ``TypeError``, so in that case this falls back to the first dimension of
        ``input_ids`` (see :attr:`num_sequences`).
        """
        if self.input_tokens is None:
            return self.num_sequences
        return len(self.input_tokens)

    @property
    def num_sequences(self) -> int:
        """Size of the first dimension of ``input_ids``."""
        return self.input_ids.shape[0]


@dataclass(eq=False, repr=False)
class BatchEmbedding(TensorWrapper):
    """Embeddings produced by the embedding process using :meth:`~inseq.models.AttributionModel.embed`.

    Attributes:
        input_embeds (:obj:`torch.Tensor`, optional): Batch of token embeddings with shape
            ``[batch_size, longest_seq_length, embedding_size]`` for each sentence in the batch.
        baseline_embeds (:obj:`torch.Tensor`, optional): Batch of reference token embeddings with shape
            ``[batch_size, longest_seq_length, embedding_size]`` for each sentence in the batch.
    """

    input_embeds: EmbeddingsTensor | None = None
    baseline_embeds: EmbeddingsTensor | None = None

    def __len__(self) -> int:
        """Return the size of the first dimension of ``input_embeds``, or 0 when no embeddings are set.

        Python requires ``__len__`` to return a non-negative int. Returning ``None`` (as previously done when
        ``input_embeds`` was unset) made both ``len(...)`` and truth-testing this object raise ``TypeError``.
        """
        if self.input_embeds is None:
            return 0
        return self.input_embeds.shape[0]


@dataclass(eq=False, repr=False)
class Batch(TensorWrapper):
    """Batch of input data for the attribution model.

    Attributes:
        encoding (:class:`~inseq.data.BatchEncoding`): Output produced by the tokenization process using
            :meth:`~inseq.models.AttributionModel.encode`.
        embedding (:class:`~inseq.data.BatchEmbedding`): Embeddings produced by the embedding process using
            :meth:`~inseq.models.AttributionModel.embed`.

    All attribute fields of ``encoding`` and ``embedding`` are exposed as read/write properties that delegate to
    the wrapped objects (e.g. ``batch.input_ids`` corresponds to ``batch.encoding.input_ids``).
    """

    encoding: BatchEncoding
    embedding: BatchEmbedding

    # --- Properties delegating to ``self.encoding`` ---

    @property
    def input_ids(self) -> IdsTensor:
        return self.encoding.input_ids

    @input_ids.setter
    def input_ids(self, value: IdsTensor) -> None:
        self.encoding.input_ids = value

    @property
    def input_tokens(self) -> OneOrMoreTokenSequences | None:
        return self.encoding.input_tokens

    @input_tokens.setter
    def input_tokens(self, value: OneOrMoreTokenSequences | None) -> None:
        # Annotated like the underlying ``BatchEncoding.input_tokens`` field, which is optional.
        self.encoding.input_tokens = value

    @property
    def attention_mask(self) -> IdsTensor:
        return self.encoding.attention_mask

    @attention_mask.setter
    def attention_mask(self, value: IdsTensor) -> None:
        self.encoding.attention_mask = value

    @property
    def baseline_ids(self) -> IdsTensor | None:
        return self.encoding.baseline_ids

    @baseline_ids.setter
    def baseline_ids(self, value: IdsTensor | None) -> None:
        self.encoding.baseline_ids = value

    # --- Properties delegating to ``self.embedding`` ---

    @property
    def input_embeds(self) -> EmbeddingsTensor | None:
        return self.embedding.input_embeds

    @input_embeds.setter
    def input_embeds(self, value: EmbeddingsTensor | None) -> None:
        self.embedding.input_embeds = value

    @property
    def baseline_embeds(self) -> EmbeddingsTensor | None:
        return self.embedding.baseline_embeds

    @baseline_embeds.setter
    def baseline_embeds(self, value: EmbeddingsTensor | None) -> None:
        self.embedding.baseline_embeds = value


@dataclass(eq=False, repr=False)
class EncoderDecoderBatch(TensorWrapper):
    """Batch of input data for the encoder-decoder attribution model, including information for the source text and
    the target prefix.

    Attributes:
        sources (:class:`~inseq.data.Batch`): Batch of input data for the source text.
        targets (:class:`~inseq.data.Batch`): Batch of input data for the target prefix.
    """

    sources: Batch
    targets: Batch

    def __getitem__(self, subscript: slice | int) -> "EncoderDecoderBatch":
        """Return a new batch whose ``targets`` are indexed with ``subscript``.

        Only the targets are indexed; ``sources`` are carried over unchanged.
        """
        return EncoderDecoderBatch(sources=self.sources, targets=self.targets[subscript])

    @property
    def max_generation_length(self) -> int:
        """Size of the second dimension of the target ``input_ids``."""
        return self.targets.input_ids.shape[1]

    @property
    def source_tokens(self) -> OneOrMoreTokenSequences | None:
        return self.sources.input_tokens

    @property
    def target_tokens(self) -> OneOrMoreTokenSequences | None:
        return self.targets.input_tokens

    @property
    def source_ids(self) -> IdsTensor:
        return self.sources.input_ids

    @property
    def target_ids(self) -> IdsTensor:
        return self.targets.input_ids

    @property
    def source_embeds(self) -> EmbeddingsTensor | None:
        # May be None: ``Batch.input_embeds`` is optional.
        return self.sources.input_embeds

    @property
    def target_embeds(self) -> EmbeddingsTensor | None:
        # May be None: ``Batch.input_embeds`` is optional.
        return self.targets.input_embeds

    @property
    def source_mask(self) -> IdsTensor:
        return self.sources.attention_mask

    @property
    def target_mask(self) -> IdsTensor:
        return self.targets.attention_mask

    def get_step_target(
        self, step: int, with_attention: bool = False
    ) -> ExpandedTargetIdsTensor | tuple[ExpandedTargetIdsTensor, ExpandedTargetIdsTensor]:
        """Return the target ids at index ``step`` along the second dimension.

        Args:
            step: Position to select in the target ``input_ids``.
            with_attention: If True, also return the target attention mask at the same position.

        Returns:
            The selected target ids, or a ``(ids, mask)`` tuple when ``with_attention`` is True.
        """
        tgt = self.targets.input_ids[:, step]
        if with_attention:
            return tgt, self.targets.attention_mask[:, step]
        return tgt


@dataclass(eq=False, repr=False)
class DecoderOnlyBatch(Batch):
    """Input batch adapted for decoder-only attribution models, including information for the target prefix.

    Provides the same ``source_*`` / ``target_*`` accessors as :class:`EncoderDecoderBatch`. The ``target_*``
    accessors delegate to the wrapped inputs, while every ``source_*`` accessor returns ``None``.
    """

    @property
    def max_generation_length(self) -> int:
        """Size of the second dimension of ``input_ids``."""
        return self.input_ids.shape[1]

    @property
    def source_tokens(self) -> OneOrMoreTokenSequences | None:
        return None

    @property
    def target_tokens(self) -> OneOrMoreTokenSequences | None:
        return self.input_tokens

    @property
    def source_ids(self) -> IdsTensor | None:
        return None

    @property
    def target_ids(self) -> IdsTensor:
        return self.input_ids

    @property
    def source_embeds(self) -> EmbeddingsTensor | None:
        return None

    @property
    def target_embeds(self) -> EmbeddingsTensor | None:
        return self.input_embeds

    @property
    def source_mask(self) -> IdsTensor | None:
        return None

    @property
    def target_mask(self) -> IdsTensor:
        return self.attention_mask

    def get_step_target(
        self, step: int, with_attention: bool = False
    ) -> ExpandedTargetIdsTensor | tuple[ExpandedTargetIdsTensor, ExpandedTargetIdsTensor]:
        """Return the input ids at index ``step`` along the second dimension.

        Args:
            step: Position to select in ``input_ids``.
            with_attention: If True, also return ``attention_mask`` at the same position.

        Returns:
            The selected ids, or an ``(ids, mask)`` tuple when ``with_attention`` is True.
        """
        tgt = self.input_ids[:, step]
        if with_attention:
            return tgt, self.attention_mask[:, step]
        return tgt

    @classmethod
    def from_batch(cls, batch: Batch) -> "DecoderOnlyBatch":
        """Build a decoder-only batch that shares ``batch``'s ``encoding`` and ``embedding`` objects (no copy)."""
        # Use ``cls`` (not a hard-coded class name) so the alternate constructor respects subclasses.
        return cls(
            encoding=batch.encoding,
            embedding=batch.embedding,
        )


def slice_batch_from_position(
    batch: DecoderOnlyBatch,
    curr_idx: int,
    alignments: list[tuple[int, int]] | list[list[tuple[int, int]]] | None = None,
) -> tuple[DecoderOnlyBatch, IdsTensor]:
    """Truncate ``batch`` at the position aligned to ``curr_idx``.

    Args:
        batch: Batch to slice. Its ``target_ids`` and ``[...]`` indexing are used.
        curr_idx: Position mapped through ``alignments`` by ``get_aligned_idx``.
        alignments: Optional alignment pairs. If a nested list is given, only its first entry is used.
            ``None`` (the default) is handled like an empty list.

    Returns:
        A tuple ``(batch[:truncate_idx], batch.target_ids[:, truncate_idx])``, where ``truncate_idx`` is the
        position returned by ``get_aligned_idx``.
    """
    # The default is None, but ``len(None)`` used to raise TypeError here; route it through the
    # empty-list path, which was already supported.
    if alignments is None:
        alignments = []
    # Nested input (presumably one alignment list per sequence): only the first list is used.
    if alignments and isinstance(alignments[0], list):
        alignments = alignments[0]
    truncate_idx = get_aligned_idx(curr_idx, alignments)
    tgt_ids = batch.target_ids[:, truncate_idx]
    return batch[:truncate_idx], tgt_ids
